Skip to content

InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values #984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 17 commits into
base: py/init-prior-uniform
Choose a base branch
from

Conversation

penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Jul 10, 2025

Part 1: Adding hasvalue and getvalue to AbstractPPL
Part 2: Removing hasvalue and getvalue from DynamicPPL
Part 3: Introducing InitContext and init!!

This is part 4/N of #967.


In Part 3 we introduced InitContext. This PR makes use of the functionality in there to replace a bunch of code that no longer needs to exist:

  • setval_and_resample! followed by model evaluation: This process was used for predict and returned, to manually store certain values in the VarInfo, which would be used in the subsequent model evaluation. We can now do this in a single step using ParamsInit.
  • initialize_values!!: very similar to the above. It would manually set values inside the varinfo, and then it would trigger an extra model evaluation to update the logp field. Again, this is directly replaced with ParamsInit.
  • evaluate_and_sample!!: direct one-to-one replacement with init!!.

There is one API change associated with this: the initial_params kwarg to sample must now be an AbstractInitStrategy. It's still optional (it will usually default to PriorInit). However, there are two implications:

  • initial_params cannot be a vector of parameters anymore. It must be ParamsInit(::NamedTuple) OR ParamsInit(::AbstractDict{VarName}).
  • Because ParamsInit expects values in unlinked space, initial_params must always be specified in unlinked space. Previously, initial_params would have to be specified in a way that matched the linking status of the underlying varinfo.

I consider both of these to be a major win for clarity. (One might argue that vectors are more convenient. Sure, you can get a vector with vi[:], but now you can just do values_as(vi, Dict{VarName,Any}) instead.)

Closes

Closes #774
Closes #797
DOES NOT CLOSE #983
Closes TuringLang/Turing.jl#2476
Closes TuringLang/Turing.jl#1775

Copy link
Contributor

github-actions bot commented Jul 10, 2025

Benchmark Report for Commit 5025592

Computer Information

Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  9.4 |                 1.3 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                540.0 |                35.3 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                296.4 |                59.1 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |                864.2 |                29.3 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               4743.3 |                22.7 |
|           Smorgasbord |       201 | reversediff |             typed |   true |                801.6 |                36.2 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |                778.6 |                11.9 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               4920.5 |                41.6 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |                709.1 |                 8.8 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              59454.5 |                38.7 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               5803.6 |                10.2 |
|               Dynamic |        10 |    mooncake |             typed |   true |                 99.8 |                26.9 |
|              Submodel |         1 |    mooncake |             typed |   true |                 12.6 |                21.6 |
|                   LDA |        12 | reversediff |             typed |   true |                813.9 |                 1.8 |

Comment on lines 126 to 134
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`
_, varinfo = last(
DynamicPPL.init!!(
rng,
model,
varinfo,
DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()),
),
)
Copy link
Member Author

@penelopeysm penelopeysm Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that, if the chain does not store varnames inside its info field, chain_sample_to_varname_dict will fail.

I don't consider this to be a problem right now because every chain obtained via Turing's sample() will contain varnames:

https://github.com/TuringLang/Turing.jl/blob/1aa95ac91a115569c742bab74f7b751ed1450309/src/mcmc/Inference.jl#L288-L290

So this is only a problem if you manually construct a chain and try to call predict on it, which I think is a highly unlikely workflow (and I'm happy to wait for people to complain if it fails). There are a few places where this happened in the test suite. I fixed them all by adding the appropriate varname dictionary.

However, it's obviously ugly. The only good way around this is, as I suggested before, to rework MCMCChains.jl.

@penelopeysm penelopeysm changed the title Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values InitContext, part 4 - Use init!! to replace evaluate_and_sample!!, predict, returned, and initialize_values Jul 10, 2025
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch 2 times, most recently from 025aa8b to b55c1e1 Compare July 10, 2025 14:24
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch 5 times, most recently from b72c3bf to 92d3542 Compare July 10, 2025 15:57
@penelopeysm penelopeysm mentioned this pull request Jul 10, 2025
20 tasks
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch 4 times, most recently from 7438b23 to d55d378 Compare July 10, 2025 16:56
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch 3 times, most recently from 12d93e5 to 7a8e7e3 Compare July 10, 2025 17:47
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch from 7e38bbe to 1d8bceb Compare July 19, 2025 22:37
@penelopeysm penelopeysm force-pushed the py/actually-use-init branch from 1d8bceb to 2edcd10 Compare July 20, 2025 00:59
Copy link
Contributor

DynamicPPL.jl documentation for PR #984 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR984/

Copy link

codecov bot commented Jul 20, 2025

Codecov Report

Attention: Patch coverage is 87.64045% with 11 lines in your changes missing coverage. Please review.

Project coverage is 80.44%. Comparing base (2f7eba8) to head (5025592).

Files with missing lines Patch % Lines
src/simple_varinfo.jl 45.45% 6 Missing ⚠️
src/test_utils/contexts.jl 85.71% 4 Missing ⚠️
src/test_utils/model_interface.jl 0.00% 1 Missing ⚠️
Additional details and impacted files
@@                    Coverage Diff                    @@
##           py/init-prior-uniform     #984      +/-   ##
=========================================================
- Coverage                  82.20%   80.44%   -1.77%     
=========================================================
  Files                         39       39              
  Lines                       4052     3998      -54     
=========================================================
- Hits                        3331     3216     -115     
- Misses                       721      782      +61     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant